{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "71a6dc70",
   "metadata": {},
   "outputs": [],
   "source": [
    "%pip install -Uqqq datasets openinference-semantic-conventions openinference-instrumentation-openai faker openai-responses openai tiktoken \"openinference-instrumentation>=0.1.38\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "52b788701385cc6f",
   "metadata": {},
   "outputs": [],
   "source": [
    "from contextlib import ExitStack, contextmanager\n",
    "from random import choice, choices, randint, random, shuffle\n",
    "from uuid import uuid4\n",
    "\n",
    "import numpy as np\n",
    "import openai\n",
    "import pandas as pd\n",
    "from datasets import load_dataset\n",
    "from faker import Faker\n",
    "from openai_responses import OpenAIMock\n",
    "from openinference.instrumentation import dangerously_using_project, using_session, using_user\n",
    "from openinference.instrumentation.openai import OpenAIInstrumentor\n",
    "from openinference.semconv.trace import OpenInferenceSpanKindValues, SpanAttributes\n",
    "from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter\n",
    "from opentelemetry.sdk.trace import SpanLimits, StatusCode, TracerProvider\n",
    "from opentelemetry.sdk.trace.export import SimpleSpanProcessor\n",
    "from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter\n",
    "from opentelemetry.trace import format_span_id\n",
    "from tiktoken import encoding_for_model\n",
    "\n",
    "import phoenix as px\n",
    "from phoenix.trace.span_evaluations import SpanEvaluations\n",
    "\n",
    "# Non-English locales, so generated names/addresses/cities are largely non-ASCII text.\n",
    "fake = Faker(\n",
    "    [\"ja_JP\", \"vi_VN\", \"ko_KR\", \"zh_CN\", \"th_TH\", \"bn_BD\"]\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0b9e29a5",
   "metadata": {},
   "source": [
    "# Download Data\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ce04896a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Test split of a multi-turn chat dataset; only the \"chosen\" conversations are used.\n",
    "# Each entry is a sequence of message dicts with a \"content\" key (consumed by simulate_openai).\n",
    "path = \"GitBag/ultrainteract_multiturn_1_iter_processed_harvard\"\n",
    "convo = load_dataset(path, split=\"test\").to_pandas()[\"chosen\"]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7c62ce0f",
   "metadata": {},
   "source": [
    "# Tracer Provider\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d87e76da34fee68b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# OTLP/gRPC exporter for a locally running collector (presumably the Phoenix server).\n",
    "# It is deliberately NOT registered on the provider: spans are buffered in memory\n",
    "# and pushed manually, shuffled, by `export_spans`.\n",
    "endpoint = \"http://127.0.0.1:4317\"\n",
    "otlp_span_exporter = OTLPSpanExporter(endpoint=endpoint)\n",
    "\n",
    "# Buffer every finished span in memory. The huge attribute-count limit keeps the SDK\n",
    "# from dropping attributes on spans that carry many of them.\n",
    "in_memory_span_exporter = InMemorySpanExporter()\n",
    "tracer_provider = TracerProvider(span_limits=SpanLimits(max_attributes=1_000_000))\n",
    "tracer_provider.add_span_processor(SimpleSpanProcessor(in_memory_span_exporter))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "df1cf375",
   "metadata": {},
   "source": [
    "# Helpers\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4f60b2ad",
   "metadata": {},
   "outputs": [],
   "source": [
    "def _gen_fuzzy_id(make_text):\n",
    "    \"\"\"Return a randomly shaped identifier to exercise downstream ID handling.\n",
    "\n",
    "    ~10%: a short run of 1-5 colons (a delimiter-only string);\n",
    "    ~80%: ``make_text()`` (a localized Faker string);\n",
    "    ~10%: a non-string integer in [0, 1e9).\n",
    "    \"\"\"\n",
    "    p = random()\n",
    "    if p < 0.1:\n",
    "        return \":\" * randint(1, 5)\n",
    "    if p < 0.9:\n",
    "        return make_text()\n",
    "    # random() already lies in [0, 1), so no abs() is needed.\n",
    "    return int(random() * 1_000_000_000)\n",
    "\n",
    "\n",
    "def gen_session_id():\n",
    "    \"\"\"Random session ID; the text variant is a Faker address.\"\"\"\n",
    "    # lambda defers the (locale-selecting) attribute lookup until it is actually needed\n",
    "    return _gen_fuzzy_id(lambda: fake.address())\n",
    "\n",
    "\n",
    "def gen_user_id():\n",
    "    \"\"\"Random user ID; the text variant is a Faker name.\"\"\"\n",
    "    return _gen_fuzzy_id(lambda: fake.name())\n",
    "\n",
    "\n",
    "def export_spans(prob_drop_root):\n",
    "    \"\"\"Export finished spans one by one, in random order, for receiver testing.\n",
    "\n",
    "    All spans buffered in ``in_memory_span_exporter`` are taken (and the buffer\n",
    "    cleared), shuffled, and sent individually through ``otlp_span_exporter`` so\n",
    "    the receiver sees children before parents and interleaved traces.\n",
    "\n",
    "    Args:\n",
    "        prob_drop_root: probability in [0, 1] that a root span (no parent) is\n",
    "            skipped, leaving its descendants orphaned.\n",
    "\n",
    "    Returns:\n",
    "        The spans that were successfully exported; dropped roots and spans whose\n",
    "        export failed are excluded, so callers only annotate spans the receiver has.\n",
    "    \"\"\"\n",
    "    from opentelemetry.sdk.trace.export import SpanExportResult\n",
    "\n",
    "    spans = list(in_memory_span_exporter.get_finished_spans())\n",
    "    # Clear right away so a failure below can't cause re-export on the next call.\n",
    "    in_memory_span_exporter.clear()\n",
    "    shuffle(spans)\n",
    "    exported = []\n",
    "    dropped = failed = 0\n",
    "    for span in spans:\n",
    "        if span.parent is None and random() < prob_drop_root:\n",
    "            dropped += 1\n",
    "            continue\n",
    "        if otlp_span_exporter.export([span]) is SpanExportResult.SUCCESS:\n",
    "            exported.append(span)\n",
    "        else:\n",
    "            failed += 1\n",
    "    # Count only what was actually sent; falsy-but-valid IDs (e.g. 0) still count.\n",
    "    session_count = len(\n",
    "        {id_ for span in exported if (id_ := span.attributes.get(\"session.id\")) is not None}\n",
    "    )\n",
    "    trace_count = len({span.context.trace_id for span in exported})\n",
    "    print(\n",
    "        f\"Exported {session_count} sessions, {trace_count} traces, {len(exported)} spans\"\n",
    "        f\" (dropped {dropped} root spans, {failed} failed exports)\"\n",
    "    )\n",
    "    return exported\n",
    "\n",
    "\n",
    "def rand_span_kind():\n",
    "    \"\"\"Yield a single (attribute key, value) pair holding a random OpenInference span kind.\"\"\"\n",
    "    kinds = list(OpenInferenceSpanKindValues)\n",
    "    yield (SpanAttributes.OPENINFERENCE_SPAN_KIND, choice(kinds).value)\n",
    "\n",
    "\n",
    "def rand_status_code():\n",
    "    \"\"\"Pick a span status code: OK 98%, ERROR 1%, UNSET 1%.\"\"\"\n",
    "    population = [StatusCode.OK, StatusCode.ERROR, StatusCode.UNSET]\n",
    "    (code,) = choices(population, weights=[0.98, 0.01, 0.01], k=1)\n",
    "    return code\n",
    "\n",
    "\n",
    "@contextmanager\n",
    "def trace_tree(tracer, n=5):\n",
    "    \"\"\"Wrap the caller's block in a randomly shaped tree of spans.\n",
    "\n",
    "    Opens a span named after a random city with a random OpenInference span\n",
    "    kind, then creates 0..n child subtrees (depth parameter 0..n-1) before and\n",
    "    0..n after the point where control is handed to the caller. The caller's\n",
    "    block runs exactly once: somewhere inside one of the leading subtrees, or\n",
    "    directly under this span. With ``n <= 0`` no span is created.\n",
    "\n",
    "    On success each span gets a random status; every span is ended at a random\n",
    "    time up to ~5s in the future, independent of its parent's end time. If the\n",
    "    caller's block raises, the span keeps whatever status the SDK recorded for\n",
    "    the exception and is still ended, so it still reaches the span processors.\n",
    "    \"\"\"\n",
    "    if n <= 0:\n",
    "        yield\n",
    "        return\n",
    "    has_yielded = False\n",
    "    root = None\n",
    "    try:\n",
    "        with tracer.start_as_current_span(\n",
    "            fake.city(),\n",
    "            attributes=dict(rand_span_kind()),\n",
    "            end_on_exit=False,\n",
    "        ) as root:\n",
    "            for _ in range(randint(0, n)):\n",
    "                with trace_tree(tracer, randint(0, n - 1)):\n",
    "                    # Hand control to the caller inside this subtree with 50% odds.\n",
    "                    if not has_yielded and random() < 0.5:\n",
    "                        has_yielded = True\n",
    "                        yield\n",
    "            if not has_yielded:\n",
    "                # No leading subtree took the caller's block; run it under this span.\n",
    "                has_yielded = True\n",
    "                yield\n",
    "            for _ in range(randint(0, n)):\n",
    "                with trace_tree(tracer, randint(0, n - 1)):\n",
    "                    pass\n",
    "        # Only reached on success; on error the SDK-recorded status is kept.\n",
    "        root.set_status(rand_status_code())\n",
    "    finally:\n",
    "        # end_on_exit=False above lets us choose a custom (future) end timestamp.\n",
    "        if root is not None:\n",
    "            root.end(int(fake.future_datetime(\"+5s\").timestamp() * 10**9))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b922a0c7",
   "metadata": {},
   "source": [
    "# Generate Sessions (For Demos)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "98a91087",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Demo fixtures: a handful of sessions written to their own project.\n",
    "project_name = \"Sessions-Fixtures\"\n",
    "session_count = randint(5, 10)\n",
    "\n",
    "\n",
    "def simulate_openai():\n",
    "    \"\"\"Simulate one chat session against a mocked OpenAI API.\n",
    "\n",
    "    Samples 1-10 conversations from ``convo`` and concatenates their messages.\n",
    "    For each odd index ``i`` it mocks the completion to return ``messages[i]``\n",
    "    (token usage counted with tiktoken) and calls the instrumented client with\n",
    "    ``messages[:i]`` under a \"root\" span. Every turn carries the same session\n",
    "    and user ID as span attributes.\n",
    "    \"\"\"\n",
    "    user_id = Faker().user_name()\n",
    "    session_id = str(uuid4())\n",
    "    client = openai.Client(api_key=\"sk-\")  # dummy key: calls are served by OpenAIMock\n",
    "    model = \"gpt-4o-mini\"\n",
    "    encoding = encoding_for_model(model)\n",
    "    # NOTE(review): assumes the concatenated messages alternate user/assistant from index 0.\n",
    "    messages = np.concatenate(convo.sample(randint(1, 10)).values)\n",
    "    # disallowed_special=() counts special-token text (e.g. \"<|endoftext|>\") as plain\n",
    "    # text instead of raising ValueError on arbitrary dataset content.\n",
    "    counts = [len(encoding.encode(m[\"content\"], disallowed_special=())) for m in messages]\n",
    "    openai_mock = OpenAIMock()\n",
    "    tracer = tracer_provider.get_tracer(__name__)\n",
    "    with openai_mock.router:\n",
    "        for i in range(1, len(messages), 2):\n",
    "            openai_mock.chat.completions.create.response = dict(\n",
    "                choices=[dict(index=0, finish_reason=\"stop\", message=messages[i])],\n",
    "                usage=dict(\n",
    "                    prompt_tokens=sum(counts[:i]),\n",
    "                    completion_tokens=counts[i],\n",
    "                    total_tokens=sum(counts[: i + 1]),\n",
    "                ),\n",
    "            )\n",
    "            attributes = {\n",
    "                \"input.value\": messages[i - 1][\"content\"],\n",
    "                \"output.value\": messages[i][\"content\"],\n",
    "                \"session.id\": session_id,\n",
    "                \"user.id\": user_id,\n",
    "            }\n",
    "            # Single context manager, so no ExitStack is needed.\n",
    "            with tracer.start_as_current_span(\"root\", attributes=attributes) as root:\n",
    "                client.chat.completions.create(model=model, messages=messages[:i])\n",
    "                root.set_status(StatusCode.OK)\n",
    "\n",
    "\n",
    "OpenAIInstrumentor().instrument(tracer_provider=tracer_provider)\n",
    "try:\n",
    "    with dangerously_using_project(project_name):\n",
    "        for _ in range(session_count):\n",
    "            simulate_openai()\n",
    "finally:\n",
    "    # Always uninstrument so later cells / re-runs start uninstrumented.\n",
    "    OpenAIInstrumentor().uninstrument()\n",
    "spans = export_spans(0)\n",
    "\n",
    "# Annotate a random half of the root spans with a fake score/label per eval name.\n",
    "root_span_ids = pd.Series(\n",
    "    [format_span_id(span.context.span_id) for span in spans if span.parent is None]\n",
    ")\n",
    "phoenix_client = px.Client()  # one client, reused for every eval name\n",
    "for name in [\"Helpfulness\", \"Relevance\", \"Engagement\"]:\n",
    "    span_ids = root_span_ids.sample(frac=0.5)\n",
    "    if span_ids.empty:\n",
    "        continue  # nothing to annotate (e.g. too few root spans)\n",
    "    df = pd.DataFrame(\n",
    "        {\n",
    "            \"context.span_id\": span_ids,\n",
    "            \"score\": np.random.rand(len(span_ids)),\n",
    "            \"label\": np.random.choice([\"👍\", \"👎\"], len(span_ids)),\n",
    "        }\n",
    "    ).set_index(\"context.span_id\")\n",
    "    phoenix_client.log_evaluations(SpanEvaluations(name, df))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a2f2ac17",
   "metadata": {},
   "source": [
    "# Generate Sessions (For Development)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "caedf0f3bc9b454c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Development fixtures: knobs controlling how messy the generated data is.\n",
    "prob_drop_root = 0.0  # probability that a root span gets dropped\n",
    "tree_complexity = 4  # set to 0 for single span under root\n",
    "session_count = randint(5, 10)\n",
    "\n",
    "\n",
    "def simulate_openai():\n",
    "    \"\"\"Simulate one messy chat session against a mocked OpenAI API.\n",
    "\n",
    "    - Session/user IDs come from ``gen_session_id``/``gen_user_id``; 10% of\n",
    "      sessions use a blank ``\" \"`` user ID.\n",
    "    - Per turn, the IDs are either set directly as root-span attributes or put\n",
    "      in context via ``using_session``/``using_user`` (50/50).\n",
    "    - The LLM call is nested in a random ``trace_tree`` of ``tree_complexity``.\n",
    "    - The root span gets a random status and a random future end time.\n",
    "    \"\"\"\n",
    "    user_id = gen_user_id() if random() < 0.9 else \" \"\n",
    "    session_id = gen_session_id()\n",
    "    client = openai.Client(api_key=\"sk-\")  # dummy key: calls are served by OpenAIMock\n",
    "    model = \"gpt-4o-mini\"\n",
    "    encoding = encoding_for_model(model)\n",
    "    # NOTE(review): assumes the concatenated messages alternate user/assistant from index 0.\n",
    "    messages = np.concatenate(convo.sample(randint(1, 10)).values)\n",
    "    # disallowed_special=() counts special-token text (e.g. \"<|endoftext|>\") as plain\n",
    "    # text instead of raising ValueError on arbitrary dataset content.\n",
    "    counts = [len(encoding.encode(m[\"content\"], disallowed_special=())) for m in messages]\n",
    "    openai_mock = OpenAIMock()\n",
    "    tracer = tracer_provider.get_tracer(__name__)\n",
    "    with openai_mock.router:\n",
    "        for i in range(1, len(messages), 2):\n",
    "            openai_mock.chat.completions.create.response = dict(\n",
    "                choices=[dict(index=0, finish_reason=\"stop\", message=messages[i])],\n",
    "                usage=dict(\n",
    "                    prompt_tokens=sum(counts[:i]),\n",
    "                    completion_tokens=counts[i],\n",
    "                    total_tokens=sum(counts[: i + 1]),\n",
    "                ),\n",
    "            )\n",
    "            root = None\n",
    "            try:\n",
    "                with ExitStack() as stack:\n",
    "                    attributes = {\n",
    "                        \"input.value\": messages[i - 1][\"content\"],\n",
    "                        \"output.value\": messages[i][\"content\"],\n",
    "                    }\n",
    "                    if random() < 0.5:\n",
    "                        attributes[\"session.id\"] = session_id\n",
    "                        attributes[\"user.id\"] = user_id\n",
    "                    else:\n",
    "                        stack.enter_context(using_session(session_id))\n",
    "                        stack.enter_context(using_user(user_id))\n",
    "                    # end_on_exit=False: the root is ended below with a custom timestamp.\n",
    "                    root = stack.enter_context(\n",
    "                        tracer.start_as_current_span(\n",
    "                            \"root\",\n",
    "                            attributes=attributes,\n",
    "                            end_on_exit=False,\n",
    "                        )\n",
    "                    )\n",
    "                    with trace_tree(tracer, tree_complexity):\n",
    "                        client.chat.completions.create(model=model, messages=messages[:i])\n",
    "                root.set_status(rand_status_code())\n",
    "            finally:\n",
    "                # End the root even if the call raised, so it is not left unended/unexported.\n",
    "                if root is not None:\n",
    "                    root.end(int(fake.future_datetime(\"+5s\").timestamp() * 10**9))\n",
    "\n",
    "\n",
    "OpenAIInstrumentor().instrument(tracer_provider=tracer_provider)\n",
    "try:\n",
    "    for _ in range(session_count):\n",
    "        simulate_openai()\n",
    "finally:\n",
    "    # Always uninstrument so later cells / re-runs start uninstrumented.\n",
    "    OpenAIInstrumentor().uninstrument()\n",
    "spans = export_spans(prob_drop_root)\n",
    "\n",
    "# Annotate a random half of the root spans for each of three eval names (A, B, C).\n",
    "# format_span_id renders the 64-bit span ID as 16 lowercase hex digits.\n",
    "root_span_ids = pd.Series(\n",
    "    [format_span_id(span.context.span_id) for span in spans if span.parent is None]\n",
    ")\n",
    "phoenix_client = px.Client()  # one client, reused for every eval name\n",
    "for name in \"ABC\":\n",
    "    span_ids = root_span_ids.sample(frac=0.5)\n",
    "    if span_ids.empty:\n",
    "        continue  # nothing to annotate (e.g. no root spans)\n",
    "    df = pd.DataFrame(\n",
    "        {\n",
    "            \"context.span_id\": span_ids,\n",
    "            \"score\": np.random.rand(len(span_ids)),\n",
    "            \"label\": np.random.choice([\"👍\", \"👎\"], len(span_ids)),\n",
    "            \"explanation\": [fake.paragraph(10) for _ in range(len(span_ids))],\n",
    "        }\n",
    "    ).set_index(\"context.span_id\")\n",
    "    phoenix_client.log_evaluations(SpanEvaluations(name, df))"
   ]
  }
 ],
 "metadata": {
  "language_info": {
   "name": "python"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
